{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e82e44e1",
   "metadata": {},
   "source": [
    "## Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d67b3443",
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "import torch\n",
    "import numpy as np\n",
    "import warnings\n",
    "from pathlib import Path\n",
    "warnings.simplefilter(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b64635de",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _step_index(path):\n",
    "    \"\"\"Sort key ordering `<prefix>_<step>.npy` paths by their integer step.\n",
    "\n",
    "    Paths whose suffix is not an integer sort after all numbered ones, in\n",
    "    lexicographic order.\n",
    "    \"\"\"\n",
    "    suffix = Path(path).stem.rsplit(\"_\", 1)[-1]\n",
    "    return (0, int(suffix), path) if suffix.isdecimal() else (1, 0, path)\n",
    "\n",
    "\n",
    "# glob.glob returns paths in arbitrary, filesystem-dependent order. Sort both\n",
    "# lists the same way so that obs[i] and pose_delta[i] come from the same file\n",
    "# index and the sequence is replayed in order.\n",
    "# NOTE(review): assumes files are named obs_<step>.npy / pose_delta_<step>.npy\n",
    "# with an integer step - confirm against the contents of demo_data.\n",
    "obs_paths = sorted(glob.glob(\"demo_data/obs_*.npy\"), key=_step_index)\n",
    "pose_delta_paths = sorted(\n",
    "    glob.glob(\"demo_data/pose_delta_*.npy\"), key=_step_index\n",
    ")\n",
    "\n",
    "# Fail early with a readable message instead of an opaque np.stack error\n",
    "if not obs_paths:\n",
    "    raise FileNotFoundError(\"No files match demo_data/obs_*.npy\")\n",
    "if len(obs_paths) != len(pose_delta_paths):\n",
    "    raise ValueError(\n",
    "        f\"Found {len(obs_paths)} observation files but \"\n",
    "        f\"{len(pose_delta_paths)} pose delta files\"\n",
    "    )\n",
    "\n",
    "device = torch.device(\"cuda:0\")\n",
    "\n",
    "# Stack the per-step arrays along a new leading (step) axis\n",
    "obs = np.stack([np.load(obs_path) for obs_path in obs_paths])\n",
    "\n",
    "pose_delta = np.stack([np.load(pose_delta_path)\n",
    "                       for pose_delta_path in pose_delta_paths])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6c3fc9b0",
   "metadata": {},
   "source": [
    "## LSeg Inference Usage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0092b55a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Pick four sample frames, keep their first three channels (used as RGB below)\n",
    "# and move the channel axis last: (N, 3, H, W) -> (N, H, W, 3)\n",
    "sample_frames = [2, 5, 20, 50]\n",
    "rgb = obs[sample_frames, :3].transpose(0, 2, 3, 1)\n",
    "rgb.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8528e6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the pre-trained LSeg model.\n",
    "# The checkpoint must be downloaded first: follow the README instructions\n",
    "# in the lseg folder.\n",
    "\n",
    "from home_robot.agent.perception.detection.lseg import load_lseg_for_inference\n",
    "\n",
    "# The checkpoint is looked up relative to the directory three levels above\n",
    "# the current working directory.\n",
    "base_dir = Path().resolve().parent.parent.parent\n",
    "checkpoint_path = base_dir / \"perception/detection/lseg/checkpoints/demo_e200.ckpt\"\n",
    "\n",
    "device = torch.device(\"cuda:0\")\n",
    "model = load_lseg_for_inference(checkpoint_path, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa28b4f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode pixels to CLIP features\n",
    "# `rgb` is the channels-last batch of sample frames built in an earlier cell.\n",
    "# NOTE(review): the layout of the returned features is not documented here -\n",
    "# check it against the shape displayed by this cell.\n",
    "\n",
    "pixel_features = model.encode(rgb)\n",
    "pixel_features.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c858a105",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Turn the per-pixel CLIP features into text labels. The label list is only\n",
    "# supplied here, so new labels can be introduced at inference time.\n",
    "\n",
    "labels = [\n",
    "    \"tree\",\n",
    "    \"chair\",\n",
    "    \"clock\",\n",
    "    \"couch\",\n",
    "    \"cushion\",\n",
    "    \"lamp\",\n",
    "    \"cabinet\",\n",
    "    \"other\",\n",
    "]\n",
    "one_hot_predictions, visualizations = model.decode(pixel_features, labels)\n",
    "print(one_hot_predictions.shape, visualizations.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1cdfe643",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize predictions: one row per sample frame, with the input image on the\n",
    "# left and the visualization returned by model.decode on the right.\n",
    "\n",
    "num_frames = len(rgb)\n",
    "# squeeze=False keeps axarr two-dimensional even for a single frame;\n",
    "# 3.5 inches per row reproduces the former (9, 14) figure for four frames.\n",
    "f, axarr = plt.subplots(\n",
    "    num_frames, 2, figsize=(9, 3.5 * num_frames), squeeze=False\n",
    ")\n",
    "for row, (frame, prediction) in enumerate(zip(rgb, visualizations)):\n",
    "    # Cast to int so imshow interprets the values on the 0-255 scale\n",
    "    axarr[row, 0].imshow(frame.astype(int))\n",
    "    axarr[row, 0].set_title(f\"Input frame {row}\")\n",
    "    axarr[row, 1].imshow(prediction)\n",
    "    axarr[row, 1].set_title(f\"Predicted labels {row}\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f9cc3fe",
   "metadata": {},
   "source": [
    "## Map Initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "006c677e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from home_robot.agent.mapping.dense.semantic.vision_language_2d_semantic_map_state import VisionLanguage2DSemanticMapState\n",
    "from home_robot.agent.mapping.dense.semantic.vision_language_2d_semantic_map_module import VisionLanguage2DSemanticMapModule\n",
    "\n",
    "device = torch.device(\"cuda:0\")\n",
    "base_dir = Path().resolve().parent.parent.parent\n",
    "checkpoint_path = base_dir / \"perception/detection/lseg/checkpoints/demo_e200.ckpt\"\n",
    "\n",
    "# Settings that are passed to both the map state and the map module; they are\n",
    "# defined once so the two constructor calls receive the same values.\n",
    "shared_map_args = dict(\n",
    "    lseg_features_dim=512,\n",
    "    map_resolution=5,\n",
    "    map_size_cm=4800,\n",
    "    global_downscaling=2,\n",
    ")\n",
    "\n",
    "# The state object holds the global map, the local map and the sensor pose.\n",
    "# See the class definition for the meaning of each argument.\n",
    "semantic_map = VisionLanguage2DSemanticMapState(\n",
    "    device=device,\n",
    "    num_environments=1,\n",
    "    **shared_map_args\n",
    ")\n",
    "semantic_map.init_map_and_pose()\n",
    "\n",
    "# The module updates the local and global maps and poses.\n",
    "# See the class definition for the meaning of each argument.\n",
    "semantic_map_module = VisionLanguage2DSemanticMapModule(\n",
    "    lseg_checkpoint_path=checkpoint_path,\n",
    "    frame_height=480,\n",
    "    frame_width=640,\n",
    "    camera_height=0.88,\n",
    "    hfov=79.0,\n",
    "    vision_range=100,\n",
    "    du_scale=4,\n",
    "    exp_pred_threshold=1.0,\n",
    "    map_pred_threshold=1.0,\n",
    "    **shared_map_args\n",
    ").to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c05a33ec",
   "metadata": {},
   "source": [
    "## Map Update"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e7fa815",
   "metadata": {},
   "outputs": [],
   "source": [
    "temporal_batch_size = 1\n",
    "\n",
    "# Feed the recorded sequence to the map module in chunks of\n",
    "# temporal_batch_size steps, writing the updated maps and poses back into\n",
    "# semantic_map after every chunk.\n",
    "for start in range(0, len(obs), temporal_batch_size):\n",
    "    chunk = slice(start, start + temporal_batch_size)\n",
    "\n",
    "    # Each input gets a leading axis of size 1 in front of the step axis;\n",
    "    # only the first four observation channels are passed on.\n",
    "    seq_obs = torch.from_numpy(obs[chunk, :4]).unsqueeze(0).to(device)\n",
    "    seq_pose_delta = torch.from_numpy(pose_delta[chunk]).unsqueeze(0).to(device)\n",
    "\n",
    "    # Per-step flags: dones is all False, update_global is all True\n",
    "    num_steps = seq_obs.shape[1]\n",
    "    seq_dones = torch.zeros(1, num_steps, dtype=torch.bool, device=device)\n",
    "    seq_update_global = torch.ones(1, num_steps, dtype=torch.bool, device=device)\n",
    "\n",
    "    outputs = semantic_map_module(\n",
    "        seq_obs,\n",
    "        seq_pose_delta,\n",
    "        seq_dones,\n",
    "        seq_update_global,\n",
    "        semantic_map.local_map,\n",
    "        semantic_map.global_map,\n",
    "        semantic_map.local_pose,\n",
    "        semantic_map.global_pose,\n",
    "        semantic_map.lmb,\n",
    "        semantic_map.origins,\n",
    "    )\n",
    "\n",
    "    # The maps are stored directly; the per-step sequences are reduced below\n",
    "    (\n",
    "        seq_map_features,\n",
    "        semantic_map.local_map,\n",
    "        semantic_map.global_map,\n",
    "        seq_local_pose,\n",
    "        seq_global_pose,\n",
    "        seq_lmb,\n",
    "        seq_origins,\n",
    "    ) = outputs\n",
    "\n",
    "    # Keep only the entry for the last step of the chunk\n",
    "    semantic_map.local_pose = seq_local_pose[:, -1]\n",
    "    semantic_map.global_pose = seq_global_pose[:, -1]\n",
    "    semantic_map.lmb = seq_lmb[:, -1]\n",
    "    semantic_map.origins = seq_origins[:, -1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ef96554",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Global semantic map of shape\n",
    "# (batch_size, num_channels, M, M)\n",
    "# NOTE(review): this note used to say num_channels = 4 + 512, but the list\n",
    "# below names five leading channels (0-4) before the CLIP features, which\n",
    "# would give 5 + 512 - confirm against the shape displayed by this cell.\n",
    "# 0: obstacle map\n",
    "# 1: explored area\n",
    "# 2: current agent location\n",
    "# 3: past agent locations\n",
    "# 4: number of cell updates\n",
    "# 5, 6, .., 5 + 512 - 1: CLIP map cell features (lseg_features_dim=512)\n",
    "semantic_map.global_map.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "904f78a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Local semantic map visualization\n",
    "\n",
    "import cv2\n",
    "from PIL import Image\n",
    "import matplotlib.pyplot as plt\n",
    "from home_robot.agent.perception.detection.coco_maskrcnn.coco_categories import (\n",
    "    coco_categories, coco_categories_color_palette,\n",
    ")\n",
    "\n",
    "# Flat list of RGB triplets: entries 0-3 are the fixed map colors, followed\n",
    "# by the per-category colors. The floats are scaled by 255 right below.\n",
    "map_color_palette = [\n",
    "    1.0, 1.0, 1.0,  # empty space\n",
    "    0.6, 0.6, 0.6,  # obstacles\n",
    "    0.95, 0.95, 0.95,  # explored area\n",
    "    0.96, 0.36, 0.26,  # visited area\n",
    "    *coco_categories_color_palette,\n",
    "]\n",
    "map_color_palette = [int(x * 255.0) for x in map_color_palette]\n",
    "num_sem_categories = len(coco_categories)\n",
    "\n",
    "# Label the map with the category names, using \"other\" in place of the\n",
    "# last category key.\n",
    "semantic_categories_map = semantic_map.get_semantic_map(\n",
    "    0,\n",
    "    semantic_map_module.lseg,\n",
    "    labels=list(coco_categories.keys())[:-1] + [\"other\"]\n",
    ")\n",
    "\n",
    "obstacle_map = semantic_map.get_obstacle_map(0)\n",
    "explored_map = semantic_map.get_explored_map(0)\n",
    "visited_map = semantic_map.get_visited_map(0)\n",
    "\n",
    "# Shift the category ids past the four fixed palette entries\n",
    "semantic_categories_map += 4\n",
    "# \"other\" is the last label, so after the shift its id is\n",
    "# 4 + num_sem_categories - 1\n",
    "no_category_mask = semantic_categories_map == 4 + num_sem_categories - 1\n",
    "obstacle_mask = np.rint(obstacle_map) == 1\n",
    "explored_mask = np.rint(explored_map) == 1\n",
    "visited_mask = visited_map == 1\n",
    "\n",
    "# Cells without a category fall back to the fixed colors. Later assignments\n",
    "# win, so obstacles override explored area and visited cells override all.\n",
    "semantic_categories_map[no_category_mask] = 0\n",
    "semantic_categories_map[np.logical_and(no_category_mask, explored_mask)] = 2\n",
    "semantic_categories_map[np.logical_and(no_category_mask, obstacle_mask)] = 1\n",
    "semantic_categories_map[visited_mask] = 3\n",
    "\n",
    "# PIL sizes are (width, height) whereas numpy shapes are (rows, cols), i.e.\n",
    "# (height, width), so the shape has to be reversed for non-square maps.\n",
    "height, width = semantic_categories_map.shape\n",
    "semantic_map_vis = Image.new(\"P\", (width, height))\n",
    "semantic_map_vis.putpalette(map_color_palette)\n",
    "semantic_map_vis.putdata(semantic_categories_map.flatten().astype(np.uint8))\n",
    "semantic_map_vis = semantic_map_vis.convert(\"RGB\")\n",
    "# Flip vertically before plotting\n",
    "semantic_map_vis = np.flipud(semantic_map_vis)\n",
    "plt.imshow(semantic_map_vis)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
